{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ece8514e",
   "metadata": {},
   "source": [
    "#### Set environment variables in [.env](.env) for LLM API calling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "678ed8db",
   "metadata": {},
   "source": [
    "### Import Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1fb3d81-16b6-4b8c-a028-880fdce5e14a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "import subprocess\n",
    "import sys\n",
    "from typing import Any\n",
    "\n",
    "from azure.identity import get_bearer_token_provider, AzureCliCredential\n",
    "from dotenv import load_dotenv\n",
    "from openai import AzureOpenAI\n",
    "from tqdm import tqdm\n",
    "\n",
    "# The path tweak must run before the promptwizard imports below.\n",
    "sys.path.insert(0, \"../../\")\n",
    "import promptwizard\n",
    "from promptwizard.glue.promptopt.instantiate import GluePromptOpt\n",
    "from promptwizard.glue.promptopt.techniques.common_logic import DatasetSpecificProcessing\n",
    "from promptwizard.glue.common.utils.file import save_jsonlist\n",
    "\n",
    "load_dotenv(override = True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc9b746c",
   "metadata": {},
   "source": [
    "### The code below can be used for LLM-as-a-judge evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "26719362",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_between(start, end, text):\n",
    "    \"\"\"\n",
    "    Return the portion of 'text' enclosed by the first 'start' marker and the next 'end' marker.\n",
    "\n",
    "    Parameters:\n",
    "    - start (str): Marker that opens the region of interest.\n",
    "    - end (str): Marker that closes the region of interest.\n",
    "    - text (str): String to scan.\n",
    "\n",
    "    Returns:\n",
    "    - str: The characters strictly between the two markers, or '' when either\n",
    "      marker cannot be found ('end' is only searched for after 'start').\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # str.index raises ValueError where str.find would return -1.\n",
    "        content_begin = text.index(start) + len(start)\n",
    "        content_stop = text.index(end, content_begin)\n",
    "    except ValueError:\n",
    "        return ''\n",
    "    return text[content_begin:content_stop]\n",
    "\n",
    "def call_api(messages):\n",
    "    \"\"\"\n",
    "    Send a chat conversation to an Azure OpenAI deployment and return the reply text.\n",
    "\n",
    "    Authentication uses an Azure CLI credential wrapped in a bearer-token provider.\n",
    "    The angle-bracket values below look like placeholders (API version, endpoint,\n",
    "    deployment name) - presumably to be replaced with real settings before running.\n",
    "\n",
    "    Parameters:\n",
    "    - messages (list): Chat messages forwarded unchanged to the chat-completions call.\n",
    "\n",
    "    Returns:\n",
    "    - The `content` of the first choice's message in the response.\n",
    "    \"\"\"\n",
    "    credential_scope = \"https://cognitiveservices.azure.com/.default\"\n",
    "    ad_token_provider = get_bearer_token_provider(AzureCliCredential(), credential_scope)\n",
    "\n",
    "    azure_client = AzureOpenAI(\n",
    "        azure_ad_token_provider=ad_token_provider,\n",
    "        azure_endpoint=\"<AZURE_ENDPOINT>\",\n",
    "        api_version=\"<OPENAI_API_VERSION>\",\n",
    "    )\n",
    "    # temperature=0.0 is passed on every request.\n",
    "    completion = azure_client.chat.completions.create(\n",
    "        messages=messages,\n",
    "        model=\"<MODEL_DEPLOYMENT_NAME>\",\n",
    "        temperature=0.0,\n",
    "    )\n",
    "    return completion.choices[0].message.content\n",
    "\n",
    "def llm_eval(predicted_answer,gt_answer):\n",
    "    \"\"\"\n",
    "    Ask the LLM judge whether a predicted answer means the same as the reference answer.\n",
    "\n",
    "    The judge is prompted to reply with True or False between <ANS_START> and <ANS_END>\n",
    "    tags; the text between the tags is parsed leniently (surrounding whitespace and\n",
    "    letter case are ignored).\n",
    "\n",
    "    Parameters:\n",
    "    - predicted_answer: Answer produced by the model under evaluation.\n",
    "    - gt_answer: Reference (ground-truth) answer.\n",
    "\n",
    "    Returns:\n",
    "    - bool: True only when the tagged verdict is 'true'; False for 'false', for a\n",
    "      missing or malformed tag pair, or for an empty reply.\n",
    "    \"\"\"\n",
    "    EVAL_PROMPT = f\"\"\"Given the Predicted_Answer and Reference_Answer, compare them and check they mean the same.\n",
    "                    If they mean the same then return True between <ANS_START> and <ANS_END> tags , \n",
    "                    If they differ in the meaning then return False between <ANS_START> and <ANS_END> tags \n",
    "                    Following are the given :\n",
    "                    Predicted_Answer: {predicted_answer}\n",
    "                    Reference_Answer: {gt_answer}\"\"\"\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": \"\"},\n",
    "        {\"role\": \"user\", \"content\": EVAL_PROMPT}\n",
    "    ]\n",
    "\n",
    "    # Fall back to '' so a reply with no text content is scored False instead of raising.\n",
    "    response = call_api(messages) or \"\"\n",
    "    final_judgement = extract_between(start=\"<ANS_START>\", end=\"<ANS_END>\", text=response)\n",
    "    # Models often pad the verdict with spaces/newlines or vary its casing ('true', ' True ').\n",
    "    return final_judgement.strip().lower() == \"true\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a5084d7",
   "metadata": {},
   "source": [
    "### Create a dataset specific class and define the required functions "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5f325d33",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# True  -> answers are graded by the LLM judge (llm_eval).\n",
    "# False -> answers are graded by case-insensitive exact string match.\n",
    "llm_as_judge_eval = True\n",
    "\n",
    "class BBH(DatasetSpecificProcessing):\n",
    "    \"\"\"Dataset-specific hooks: JSONL export, final-answer extraction and answer grading.\"\"\"\n",
    "\n",
    "    def dataset_to_jsonl(self, dataset_jsonl: str, **kwargs: Any) -> None:\n",
    "        \"\"\"\n",
    "        Write the samples in kwargs[\"dataset\"] to a JSONL file.\n",
    "\n",
    "        Parameters:\n",
    "        - dataset_jsonl (str): Destination path, passed to save_jsonlist in \"w\" mode.\n",
    "        - kwargs[\"dataset\"]: Iterable of mappings with 'question' and 'answer' keys.\n",
    "        \"\"\"\n",
    "        def extract_answer_from_output(completion):\n",
    "            # The stored answer is used as the final answer without further parsing.\n",
    "            return completion\n",
    "\n",
    "        examples_set = []\n",
    "\n",
    "        # Iterating the dataset directly (no enumerate wrapper) lets tqdm see its length.\n",
    "        for sample in tqdm(kwargs[\"dataset\"], desc=\"Evaluating samples\"):\n",
    "            example = {\n",
    "              DatasetSpecificProcessing.QUESTION_LITERAL: sample['question'],\n",
    "              DatasetSpecificProcessing.ANSWER_WITH_REASON_LITERAL: sample['answer'],\n",
    "              DatasetSpecificProcessing.FINAL_ANSWER_LITERAL: extract_answer_from_output(sample[\"answer\"])\n",
    "            }\n",
    "            examples_set.append(example)\n",
    "\n",
    "        save_jsonlist(dataset_jsonl, examples_set, \"w\")\n",
    "\n",
    "    def extract_final_answer(self, answer: str):\n",
    "        \"\"\"Return the text between <ANS_START> and <ANS_END> in 'answer' ('' if the tags are missing).\"\"\"\n",
    "        return extract_between(text=answer,start=\"<ANS_START>\",end=\"<ANS_END>\")\n",
    "\n",
    "    def access_answer(self, llm_output: str, gt_answer: str):\n",
    "        \"\"\"\n",
    "        Grade one model output against the ground-truth answer.\n",
    "\n",
    "        Parameters:\n",
    "        - llm_output (str): Raw model output containing the tagged final answer.\n",
    "        - gt_answer (str): Ground-truth answer.\n",
    "\n",
    "        Returns:\n",
    "        - tuple: (is_correct, predicted_answer), in both grading modes.\n",
    "        \"\"\"\n",
    "        predicted_answer = self.extract_final_answer(llm_output)\n",
    "\n",
    "        if llm_as_judge_eval:\n",
    "            is_correct = bool(llm_eval(predicted_answer,gt_answer))\n",
    "        else:\n",
    "            # An empty extraction never counts as correct.\n",
    "            is_correct = bool(predicted_answer) and (predicted_answer.lower() == gt_answer.lower())\n",
    "\n",
    "        # The return sits outside the if/else so the judge branch also yields a result.\n",
    "        return is_correct, predicted_answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f384eb57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset processor instance; later cells call its dataset_to_jsonl method and pass it to GluePromptOpt.\n",
    "bbh_processor = BBH()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec7d1396",
   "metadata": {},
   "source": [
    "### Load and save the dataset\n",
    "Set the ```dataset_to_run``` variable to one of the 19 BBII datasets to choose which one the optimization runs on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "976681bd-4f43-4dbc-947e-cdb94d4824f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "dataset_list = ['informal_to_formal','letters_list','negation','orthography_starts_with','rhymes','second_word_letter','sum','diff','sentence_similarity','taxonomy_animal','auto_categorization','object_counting','odd_one_out','antonyms','word_unscrambling','cause_and_effect','common_concept','word_sorting','synonyms']\n",
    "\n",
    "# Set the dataset on which to run optimization out of the 19 \n",
    "dataset_to_run = 'second_word_letter'\n",
    "\n",
    "# Fail early on a typo instead of silently producing no train/test files.\n",
    "if dataset_to_run not in dataset_list:\n",
    "    raise ValueError(f\"Unknown dataset {dataset_to_run!r}; choose one of {dataset_list}\")\n",
    "\n",
    "# Creates both \"data\" and the per-dataset folder; a no-op when they already exist.\n",
    "os.makedirs(\"data/\"+dataset_to_run, exist_ok=True)\n",
    "\n",
    "# Tasks whose raw samples are plain {\"input\": ..., \"output\": ...} records.\n",
    "INPUT_OUTPUT_TASKS = {\n",
    "    'antonyms', 'diff', 'informal_to_formal', 'synonyms', 'word_unscrambling', 'word_sorting',\n",
    "    'letters_list', 'negation', 'orthography_starts_with', 'second_word_letter',\n",
    "    'sentence_similarity', 'sum', 'taxonomy_animal', 'auto_categorization',\n",
    "    'object_counting', 'odd_one_out',\n",
    "}\n",
    "\n",
    "# The raw \"execute\" split is saved as the training file, the \"induce\" split as the test file.\n",
    "SPLIT_FILE_BY_MODE = {'execute': 'train.jsonl', 'induce': 'test.jsonl'}\n",
    "\n",
    "\n",
    "def build_question_answer(task, sample):\n",
    "    \"\"\"\n",
    "    Convert one raw sample of 'task' into a (question, answer) pair.\n",
    "\n",
    "    Parameters:\n",
    "    - task (str): Dataset name, one of dataset_list.\n",
    "    - sample (dict): Raw record; the keys that are read depend on the task.\n",
    "\n",
    "    Returns:\n",
    "    - tuple: (question, answer) as plain values, never wrapped in a 1-tuple.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: If no conversion is defined for 'task'.\n",
    "    \"\"\"\n",
    "    if task == 'cause_and_effect':\n",
    "        cause = sample[\"cause\"]\n",
    "        pair = [cause, sample[\"effect\"]]\n",
    "        # Shuffle so the cause is not always presented as the first sentence.\n",
    "        random.shuffle(pair)\n",
    "        return f\"Sentence 1: {pair[0]} Sentence 2: {pair[1]}\", cause\n",
    "    if task == 'common_concept':\n",
    "        objects = \", \".join(sample[\"items\"])\n",
    "        return f\"Objects: {objects}\", f\"{sample['concept']}\"\n",
    "    if task == 'rhymes':\n",
    "        return f\"{sample['input']}\", \", \".join(sample[\"other_rhymes\"])\n",
    "    if task in INPUT_OUTPUT_TASKS:\n",
    "        return f\"{sample['input']}\", f\"{sample['output']}\"\n",
    "    raise ValueError(f\"No sample converter defined for task {task!r}\")\n",
    "\n",
    "\n",
    "# Fetch the raw data unless a checkout is already present; check=True surfaces clone failures.\n",
    "if not os.path.isdir(\"INSTINCT\"):\n",
    "    subprocess.run([\"git\", \"clone\", \"https://github.com/xqlin98/INSTINCT\"], check=True)\n",
    "\n",
    "try:\n",
    "    for mode, save_file_path in SPLIT_FILE_BY_MODE.items():\n",
    "        file_path = 'INSTINCT/Induction/experiments/data/instruction_induction/raw/'+mode+'/'+dataset_to_run+'.json'\n",
    "        with open(file_path, 'r', encoding='utf-8') as file:\n",
    "            data = json.load(file)\n",
    "\n",
    "        data_list = []\n",
    "        for sample in data['examples'].values():\n",
    "            question, answer = build_question_answer(dataset_to_run, sample)\n",
    "            data_list.append({\"question\":question,\"answer\":answer})\n",
    "\n",
    "        bbh_processor.dataset_to_jsonl(\"data/\"+dataset_to_run +\"/\"+save_file_path, dataset=data_list)\n",
    "finally:\n",
    "    # Remove the checkout even if processing fails; shutil works without a POSIX shell.\n",
    "    shutil.rmtree(\"INSTINCT\", ignore_errors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe28a967",
   "metadata": {},
   "source": [
    "### Set paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f43482f1-3e10-4cf7-8ea6-ff42c04067a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train/test JSONL files for the selected dataset.\n",
    "train_file_name = os.path.join(f\"data/{dataset_to_run}\", \"train.jsonl\")\n",
    "test_file_name = os.path.join(f\"data/{dataset_to_run}\", \"test.jsonl\")\n",
    "\n",
    "# YAML configuration files are looked up under the \"configs\" folder.\n",
    "path_to_config = \"configs\"\n",
    "llm_config_path, promptopt_config_path, setup_config_path = (\n",
    "    os.path.join(path_to_config, config_name)\n",
    "    for config_name in (\"llm_config.yaml\", \"promptopt_config.yaml\", \"setup_config.yaml\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75ac5780",
   "metadata": {},
   "source": [
    "### Create an object for calling prompt optimization and inference functionalities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af4246f-db32-4b37-a73a-f9e2e5125d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the prompt-optimization object from the promptopt and setup config paths,\n",
    "# the training JSONL path and the dataset processor instance.\n",
    "gp = GluePromptOpt(promptopt_config_path,\n",
    "                   setup_config_path,\n",
    "                   train_file_name,\n",
    "                   bbh_processor)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a26af0d",
   "metadata": {},
   "source": [
    "### Call prompt optimization function\n",
    "1. ```use_examples``` can be used when there are training samples and a mixture of real and synthetic in-context examples is required in the final prompt. When set to ```False```, all the in-context examples will be real\n",
    "2. ```generate_synthetic_examples``` can be used when there are no training samples and we want to generate synthetic examples \n",
    "3. ```run_without_train_examples``` can be used when there are no training samples and in-context examples are not required in the final prompt "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573c6151-2c03-45d9-9904-1724a1e20f1b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Function call to generate optimal prompt and expert profile.\n",
    "# Flags used here: training samples are available (use_examples=True), so the two\n",
    "# no-training-data modes (run_without_train_examples, generate_synthetic_examples) stay off.\n",
    "best_prompt, expert_profile = gp.get_best_prompt(use_examples=True,run_without_train_examples=False,generate_synthetic_examples=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef923b11",
   "metadata": {},
   "source": [
    "### Save the optimized prompt and expert profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a716af-0d77-4c7d-b1c2-6438d66096ce",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Portable and idempotent: no shell involved, and no error if the folder already exists.\n",
    "os.makedirs(\"results\", exist_ok=True)\n",
    "\n",
    "# Persist both artifacts so they can be reused without re-running the optimization.\n",
    "# NOTE: only unpickle these files from sources you trust.\n",
    "with open(\"results/best_prompt.pkl\", 'wb') as f:\n",
    "    pickle.dump(best_prompt, f)\n",
    "with open(\"results/expert_profile.pkl\", 'wb') as f:\n",
    "    pickle.dump(expert_profile, f)\n",
    "\n",
    "print(f\"Best prompt: {best_prompt} \\nExpert profile: {expert_profile}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1942c67e",
   "metadata": {},
   "source": [
    "### Evaluate the optimized prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49b5711-82dd-4d18-8cd4-ee447cf8d74c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Attach the optimized expert profile and prompt to the optimizer object before evaluating.\n",
    "gp.EXPERT_PROFILE = expert_profile\n",
    "gp.BEST_PROMPT = best_prompt\n",
    "\n",
    "# Function call to evaluate the prompt on the test file; the result is reported as accuracy.\n",
    "accuracy = gp.evaluate(test_file_name)\n",
    "\n",
    "print(f\"Final Accuracy: {accuracy}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PromptWizard",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
